import os

import yaml


def generate_data_yaml(output_path, dataset_root, train_rel, val_rel, class_names):
    """Write a dataset config YAML with keys path/train/val/nc/names.

    The ``names`` mapping assigns each class its index in *class_names*
    (0-based), and ``nc`` is the number of classes. Keys are written in
    insertion order (``sort_keys=False``). Non-ASCII text is emitted as-is
    (``allow_unicode=True``, UTF-8 file).

    Args:
        output_path: Destination file. Missing parent directories are created.
        dataset_root: Value written under ``path``. A path-like object is
            converted to a plain string.
        train_rel: Value written under ``train``. Path-likes are converted.
        val_rel: Value written under ``val``. Path-likes are converted.
        class_names: Iterable of class names in index order. It is consumed
            once, so generators are accepted.

    Raises:
        OSError: If the parent directory cannot be created or the file
            cannot be written.
    """

    def _fspath(value):
        # yaml.dump would serialize a pathlib.Path as a python-object tag;
        # convert path-likes to plain strings and leave other values alone.
        return os.fspath(value) if isinstance(value, os.PathLike) else value

    # Materialize once so `nc` and `names` are derived from the same sequence
    # even when a one-shot iterable is passed.
    names = list(class_names)

    data = {
        'path': _fspath(dataset_root),
        'train': _fspath(train_rel),
        'val': _fspath(val_rel),
        'nc': len(names),
        'names': dict(enumerate(names)),
    }

    # open(..., 'w') fails with FileNotFoundError if the directory is missing
    # (e.g. ./data on a fresh checkout), so create it first.
    if isinstance(output_path, (str, bytes, os.PathLike)):
        parent = os.path.dirname(os.fspath(output_path))
        if parent:
            os.makedirs(parent, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        yaml.dump(data, f, allow_unicode=True, sort_keys=False)

    print(f"[INFO] data.yaml 已保存至：{output_path}")


# Example invocation: class names listed in index order (index 0 = "BPSK").
class_names = [
    "BPSK", "QPSK", "8PSK", "OQPSK", "16QAM", "2ASK", "16APSK",
    "2FSK", "FM", "DSB-AM", "Single", "Chirp", "Comb",
]

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse generate_data_yaml
    # or class_names) does not write ./data/data.yaml as a side effect.
    generate_data_yaml(
        # Relative to the current working directory at run time.
        output_path=r"./data/data.yaml",
        # NOTE(review): how a relative `path` is resolved is up to the tool
        # that reads this YAML (some trainers resolve it against their own
        # datasets directory, not the CWD) -- confirm for the consumer.
        dataset_root="dataset/513_249_30dB_2/",
        # NOTE(review): the doubled "images/images" presumably matches the
        # on-disk layout of this dataset -- verify.
        train_rel="train/images/images",
        val_rel="valid/images/images",
        class_names=class_names
    )